#! /usr/bin/env python3

############################################################################
#   Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
#   SPDX-License-Identifier: MIT
############################################################################

import argparse
import sys
import yaml

# Output modes selectable on the command line.
MODE_HPP = "hpp"
MODE_CPP = "cpp"
MODES = (MODE_HPP, MODE_CPP)

# Keys looked up in the YAML definitions.  TAG_MODREF is looked up in
# the individual entries of the TAG_ARGS list; the rest are looked up
# in the top-level definition entries.
TAG_NAME = "Name"
TAG_ARGS = "Args"
TAG_COPYRIGHT = "Copyright"
TAG_ENUM_NAME = "EnumName"
TAG_EXTERN_NAME = "ExternName"
TAG_FLAGS = "Flags"
TAG_FMRB = "Fmrb"
TAG_LLVM_NAME = "LlvmName"
TAG_MODREF = "ModRef"

# Words which may appear in the whitespace-separated TAG_FLAGS value.
FLAG_DUMMY = "Dummy"
FLAG_NANOTUBE = "Nanotube"
FLAG_LLVM = "LLVM"

###########################################################################

# Templates for the generated C++ header.  Each string is written to the
# output verbatim, so the text inside the literals is generated code and
# must not be reformatted casually.

# Start of the header: include guard, the includes it needs and the
# opening of namespace nanotube.
HPP_BEGIN="""
#ifndef INTRINSIC_DEFS_HPP
#define INTRINSIC_DEFS_HPP

#include <map>
#include <string>
#include <vector>

#include "llvm/Analysis/AliasAnalysis.h"

namespace nanotube
{"""

# Opens namespace Intrinsics and the enum.  write_enum() emits one
# enumerator line per intrinsic directly after this text.
ENUM_BEGIN="""
  namespace Intrinsics
  {
    typedef enum {
"""

# Closes the enum with the trailing "end" enumerator, names the type ID
# and closes namespace Intrinsics.
ENUM_END="""
      // Not used.  Indicates the maximum value.
      end,
    } ID;
  } // namespace Intrinsics"""

# Declarations of the tables and the lookup function defined by the
# generated implementation file, followed by the end of namespace
# nanotube and of the include guard.
HPP_END="""
  extern std::map<std::string, Intrinsics::ID> intrinsic_name_to_id;
  extern std::map<Intrinsics::ID, std::string> intrinsic_id_to_name;
  Intrinsics::ID intrinsic_id_from_llvm(llvm::Intrinsic::ID id);
  extern const std::vector<llvm::FunctionModRefBehavior> intrinsic_fmrb;
  extern const std::vector<std::vector<llvm::ModRefInfo> > intrinsic_arg_info;

} // namespace nanotube

#endif // INTRINSIC_DEFS_HPP
"""

###########################################################################

# Templates for the generated C++ implementation file.  Each *_BEGIN /
# *_END pair brackets one table (or function) whose rows are produced by
# the matching write_*() method of the app class.  The strings are
# written to the output verbatim.

# Start of the implementation file: includes and using-directives.
CPP_BEGIN="""
#include "IntrinsicDefs.hpp"

#include "llvm/IR/Intrinsics.h"

using namespace nanotube;
using namespace llvm;
"""

# Opening of the map from a name string to an Intrinsics::ID.
NAME_TO_ID_BEGIN="""
std::map<std::string, Intrinsics::ID> nanotube::intrinsic_name_to_id {
"""

# Closing of the name-to-ID map initialiser.
NAME_TO_ID_END="""};
"""

# Opening of the map from an Intrinsics::ID to a name string.
ID_TO_NAME_BEGIN="""
std::map<Intrinsics::ID, std::string> nanotube::intrinsic_id_to_name {
"""

# Closing of the ID-to-name map initialiser.
ID_TO_NAME_END="""};
"""

# Start of intrinsic_id_from_llvm(): function header, the switch and the
# fixed case for llvm::Intrinsic::not_intrinsic.  Note that the text
# refers to Intrinsics::none, which is not generated by this template.
ID_FROM_LLVM_BEGIN="""
Intrinsics::ID nanotube::intrinsic_id_from_llvm(llvm::Intrinsic::ID id)
{
  switch (id)
  {
  case llvm::Intrinsic::not_intrinsic:
    return Intrinsics::none;
"""

# End of intrinsic_id_from_llvm(): the default case, which returns
# Intrinsics::llvm_unknown, and the closing braces.
ID_FROM_LLVM_END="""
  default:
    return Intrinsics::llvm_unknown;
  }
}
"""

# Opening of the intrinsic_fmrb vector.  The short macros defined here
# are the abbreviations the generated rows use; write_fmrb() copies the
# Fmrb value of each intrinsic into a row unchanged.
FMRB_BEGIN="""
#define RO  FMRB_OnlyReadsArgumentPointees
#define RW  FMRB_OnlyAccessesArgumentPointees
#define RWI FMRB_OnlyAccessesInaccessibleOrArgMem
#define I   FMRB_OnlyAccessesInaccessibleMem
#define DNA FMRB_DoesNotAccessMemory
#define UK  FMRB_UnknownModRefBehavior
/* Note: There is a lack of precision here in multiple dimensions.
 * Missing write-only, the distinction between inaccessible mem vs arg
 * mem, and read-only inaccessible + arg */
const std::vector<llvm::FunctionModRefBehavior> nanotube::intrinsic_fmrb {
"""

# Closing of the intrinsic_fmrb vector; undefines the abbreviations.
FMRB_END="""};
#undef RO
#undef RW
#undef RWI
#undef I
#undef DNA
#undef UK
"""

# Opening of the intrinsic_arg_info table.  The short macros defined
# here are the abbreviations the generated rows use; write_arg_info()
# copies the ModRef value of each argument into a row unchanged.
ARG_INFO_BEGIN="""
/* This is a very compact representation of the ModRef behaviour of the
 * intrinsics.  It is a table indexed by the intrinsic ID and the argument
 * index.  Its content is the ModRefInfo for the argument.  Note that
 * non-pointer arguments are represented as NoModRef and arguments not
 * present are represented as ModRef. */
#define R  ModRefInfo::MustRef
#define W  ModRefInfo::MustMod
#define RW ModRefInfo::MustModRef
#define N  ModRefInfo::NoModRef
#define _  ModRefInfo::ModRef
const std::vector<std::vector<ModRefInfo> > nanotube::intrinsic_arg_info = {
"""

# Closing of the intrinsic_arg_info table; undefines the abbreviations.
ARG_INFO_END="""};
#undef R
#undef W
#undef RW
#undef N
#undef _
"""

# Text written at the very end of the implementation file.
CPP_END="""
"""

###########################################################################

class app:
  """Generate the C++ intrinsic definition files from a YAML description.

  The input is a YAML list.  An entry containing the "Copyright" key
  supplies the text written at the top of the output; every other entry
  describes one intrinsic.  Depending on the selected mode, either the
  header (the enum and the declarations) or the implementation file
  (the lookup tables) is written.
  """

  def __init__(self, argv):
    """Parse the command line.

    argv is the complete argument vector, with the program name at
    index zero.  The process exits with status 1 if no output mode was
    selected or if the mode given with --mode is not one of MODES.
    """
    self.__copyright = None
    self.__intrinsics = []

    p = argparse.ArgumentParser()

    p.add_argument("-C", "--cpp", dest="mode",
                   action="store_const", const=MODE_CPP,
                   help="Output a C++ implementation file.")
    p.add_argument("-H", "--hpp", dest="mode",
                   action="store_const", const=MODE_HPP,
                   help="Output a C++ header file.")
    p.add_argument("-m", "--mode", action="store", default=None,
                   help="Set the output mode.")

    p.add_argument("input", nargs=1, action="store",
                   help="The input file to process.")
    p.add_argument("output", nargs=1, action="store",
                   help="The output file to generate.")

    self.__prog = argv[0]
    self.__args = p.parse_args(argv[1:])

    if self.__args.mode is None:
      self.__die("Mode not specified.  Use -C or -H.")

    if self.__args.mode not in MODES:
      self.__die("Mode %r not valid.  Should be one of %s." %
                 (self.__args.mode,
                  ", ".join(repr(x) for x in MODES)))

  def __die(self, message):
    """Write message to stderr, prefixed by the program name, and exit
    with status 1."""
    sys.stderr.write("%s: %s\n" % (self.__prog, message))
    sys.exit(1)

  def __check_intrinsic(self, e):
    """Validate one intrinsic definition and fill in its defaults.

    e is the mapping loaded from the YAML input.  It is updated in
    place: the flags become a list of words, and the enum, LLVM and
    extern names are derived from the name when they are not given.
    The process exits with status 1 if a required key is missing or
    has an unusable type.  Returns e.
    """
    # The checks run in a fixed order so that a definition with several
    # problems always reports the same one first.
    if TAG_NAME not in e:
      self.__die("Missing %r in definition %r." % (TAG_NAME, e))
    name = e[TAG_NAME]
    c_name = name.replace(".", "_")

    if TAG_ARGS not in e:
      self.__die("Missing %r in definition %r." % (TAG_ARGS, e))

    if not isinstance(e[TAG_ARGS], list):
      self.__die("Incorrect %r type in definition of %r." %
                 (TAG_ARGS, name))

    for i, arg in enumerate(e[TAG_ARGS]):
      # A non-mapping argument cannot carry a ModRef entry either.
      if not isinstance(arg, dict) or TAG_MODREF not in arg:
        self.__die("Missing %r in argument %d of %r." %
                   (TAG_MODREF, i, name))

    if TAG_FLAGS not in e:
      self.__die("Missing %r in definition %r." % (TAG_FLAGS, e))

    if TAG_FMRB not in e:
      self.__die("Missing %r in definition %r." % (TAG_FMRB, e))

    # Flags are normally a whitespace-separated string; an empty value
    # loads as None and a YAML sequence of words is accepted as well.
    flags = e[TAG_FLAGS]
    if flags is None:
      flags = []
    elif isinstance(flags, str):
      flags = flags.split()
    elif isinstance(flags, (list, tuple)):
      flags = list(flags)
    else:
      self.__die("Incorrect %r type in definition of %r." %
                 (TAG_FLAGS, name))
    e[TAG_FLAGS] = flags

    is_llvm = FLAG_LLVM in flags

    if TAG_ENUM_NAME not in e:
      e[TAG_ENUM_NAME] = ("llvm_" + c_name) if is_llvm else c_name

    # Only LLVM intrinsics get a default LLVM name.
    if is_llvm and TAG_LLVM_NAME not in e:
      e[TAG_LLVM_NAME] = c_name

    if TAG_EXTERN_NAME not in e:
      if FLAG_NANOTUBE in flags:
        e[TAG_EXTERN_NAME] = "nanotube_" + e[TAG_ENUM_NAME]
      else:
        e[TAG_EXTERN_NAME] = e[TAG_ENUM_NAME]

    return e

  def read_input(self):
    """Load the YAML input file and collect the intrinsic definitions.

    Entries containing the "Copyright" key set the copyright text; all
    other entries are validated and appended to the intrinsic list in
    input order.  The process exits with status 1 on malformed input.
    """
    fname = self.__args.input[0]
    with open(fname, "r") as fh:
      self.__yaml = yaml.safe_load(fh)

    # An empty document loads as None and other documents may load as a
    # mapping or a scalar; none of those can be iterated as definitions.
    if not isinstance(self.__yaml, list):
      self.__die("Input %r does not contain a list of definitions." %
                 (fname,))

    for e in self.__yaml:
      if not isinstance(e, dict):
        self.__die("Definition %r is not a mapping." % (e,))

      if TAG_COPYRIGHT in e:
        self.__copyright = e[TAG_COPYRIGHT]
        continue

      self.__intrinsics.append(self.__check_intrinsic(e))

  def write_copyright(self, fh):
    """Write the copyright text to fh, if the input supplied one."""
    # Without a "Copyright" entry the text is None, which cannot be
    # passed to write().
    if self.__copyright is not None:
      fh.write(self.__copyright)

  def write_enum(self, fh):
    """Write the Intrinsics::ID enum, one enumerator per intrinsic."""
    fh.write(ENUM_BEGIN)
    for e in self.__intrinsics:
      fh.write("      %s,\n" % (e[TAG_ENUM_NAME],))
    fh.write(ENUM_END)

  def write_hpp(self, fh):
    """Write the complete C++ header file to fh."""
    self.write_copyright(fh)
    fh.write(HPP_BEGIN)
    self.write_enum(fh)
    fh.write(HPP_END)

  def write_name_to_id(self, fh):
    """Write the extern-name to ID map.

    Only intrinsics carrying the Nanotube flag are included.
    """
    fh.write(NAME_TO_ID_BEGIN)
    for e in self.__intrinsics:
      if FLAG_NANOTUBE not in e[TAG_FLAGS]:
        continue
      fh.write("  {\"%s\", Intrinsics::%s},\n" %
               (e[TAG_EXTERN_NAME], e[TAG_ENUM_NAME]))
    fh.write(NAME_TO_ID_END)

  def write_id_to_name(self, fh):
    """Write the ID to extern-name map, covering every intrinsic."""
    fh.write(ID_TO_NAME_BEGIN)
    for e in self.__intrinsics:
      fh.write("  {Intrinsics::%s, \"%s\"},\n" %
               (e[TAG_ENUM_NAME], e[TAG_EXTERN_NAME]))
    fh.write(ID_TO_NAME_END)

  def write_id_from_llvm(self, fh):
    """Write intrinsic_id_from_llvm(), with one case per intrinsic
    carrying the LLVM flag."""
    fh.write(ID_FROM_LLVM_BEGIN)
    for e in self.__intrinsics:
      if FLAG_LLVM in e[TAG_FLAGS]:
        fh.write("\n"
                 "  case llvm::Intrinsic::%s:\n"
                 "    return Intrinsics::%s;\n" %
                 (e[TAG_LLVM_NAME], e[TAG_ENUM_NAME]))
    fh.write(ID_FROM_LLVM_END)

  def write_fmrb(self, fh):
    """Write the intrinsic_fmrb table, one row per intrinsic in enum
    order."""
    fh.write(FMRB_BEGIN)
    for e in self.__intrinsics:
      fh.write("  %-6s // %s\n" %
               (e[TAG_FMRB]+",", e[TAG_ENUM_NAME]))
    fh.write(FMRB_END)

  def write_arg_info(self, fh):
    """Write the intrinsic_arg_info table, one row per intrinsic in
    enum order, holding the ModRef value of each argument."""
    fh.write(ARG_INFO_BEGIN)
    for e in self.__intrinsics:
      # The inner variable has its own name so that it cannot be
      # confused with the intrinsic being written.
      arg_info = ",".join("%2s" % arg[TAG_MODREF] for arg in e[TAG_ARGS])
      fh.write("  %-35s // %s\n" %
               ("{"+arg_info+"},", e[TAG_ENUM_NAME]))

    fh.write(ARG_INFO_END)

  def write_cpp(self, fh):
    """Write the complete C++ implementation file to fh."""
    self.write_copyright(fh)
    fh.write(CPP_BEGIN)
    self.write_name_to_id(fh)
    self.write_id_to_name(fh)
    self.write_id_from_llvm(fh)
    self.write_fmrb(fh)
    self.write_arg_info(fh)
    fh.write(CPP_END)

  def write_output(self):
    """Create the output file and write it in the selected mode.

    Raises ValueError if the mode is neither MODE_HPP nor MODE_CPP;
    __init__ rejects such modes, so this indicates a programming error.
    """
    mode = self.__args.mode
    # Pick the writer before opening the file so that an unexpected
    # mode does not leave an empty output file behind.  An explicit
    # raise is used because assert statements are removed under -O.
    if mode == MODE_HPP:
      writer = self.write_hpp
    elif mode == MODE_CPP:
      writer = self.write_cpp
    else:
      raise ValueError("Unexpected output mode %r." % (mode,))

    fname = self.__args.output[0]
    with open(fname, "w") as fh:
      writer(fh)

  def main(self):
    """Read the input file and then write the output file."""
    self.read_input()
    self.write_output()

# Run the generator only when this file is executed as a script, so that
# importing it as a module does not parse sys.argv or write any files.
if __name__ == "__main__":
  app(sys.argv).main()
